{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Kernel SHAP\n",
    "\n",
    "This notebook provides a simple brute force version of Kernel SHAP that enumerates the entire $2^M$ sample space. We also compare to the full KernelExplainer implementation. Note that KernelExplainer does a sampling approximation for large values of $M$, but for small values it is exact."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Brute Force Kernel SHAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " reference = [0. 0. 0. 0.]\n",
      " x = [ 1.62434536 -0.61175641 -0.52817175 -1.07296862]\n",
      "shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]\n",
      " base_value = 10.000000000000002\n",
      " sum(phi) = 9.55093584213122\n",
      " f(x) = 9.55093584213122\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "import scipy.special\n",
    "\n",
    "\n",
    "def powerset(iterable):\n",
    "    \"\"\"Return an iterator over every subset of ``iterable`` as a tuple.\n",
    "\n",
    "    Subsets are produced in order of increasing size, from the empty tuple\n",
    "    up to the tuple containing every element.\n",
    "    \"\"\"\n",
    "    s = list(iterable)\n",
    "    return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1))\n",
    "\n",
    "\n",
    "def shapley_kernel(M, s):\n",
    "    \"\"\"Return the Shapley kernel weight for a coalition of size ``s`` out of ``M`` features.\n",
    "\n",
    "    For ``s == 0`` or ``s == M`` the formula below would divide by zero, so a\n",
    "    large finite weight (10000) is returned instead.\n",
    "    \"\"\"\n",
    "    if s == 0 or s == M:\n",
    "        return 10000\n",
    "    return (M - 1) / (scipy.special.binom(M, s) * s * (M - s))\n",
    "\n",
    "\n",
    "def f(X):\n",
    "    \"\"\"Toy linear model: ``np.dot(X, beta) + 10`` with fixed pseudo-random ``beta``.\n",
    "\n",
    "    ``beta`` has one entry per feature (the last axis of ``X``) and is drawn\n",
    "    from a private generator seeded with 0, so repeated calls use the same\n",
    "    coefficients without resetting NumPy's global random state.\n",
    "    \"\"\"\n",
    "    beta = np.random.RandomState(0).rand(X.shape[-1])\n",
    "    return np.dot(X, beta) + 10\n",
    "\n",
    "\n",
    "def kernel_shap(f, x, reference, M):\n",
    "    \"\"\"Brute-force Kernel SHAP that enumerates all ``2**M`` feature coalitions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    f : callable\n",
    "        Model evaluated once on a ``(2**M, M)`` array; must return ``2**M`` values.\n",
    "    x : numpy.ndarray\n",
    "        1-D array of length ``M`` holding the instance to explain.\n",
    "    reference : array-like\n",
    "        Length-``M`` values used for features that are absent from a coalition.\n",
    "    M : int\n",
    "        Number of features.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Length ``M + 1``: the ``M`` per-feature coefficients followed by the\n",
    "        intercept (base value) of the weighted least squares fit.\n",
    "    \"\"\"\n",
    "    n_coalitions = 2**M\n",
    "    # Binary coalition indicators plus a trailing column of ones for the intercept.\n",
    "    X = np.zeros((n_coalitions, M + 1))\n",
    "    X[:, -1] = 1\n",
    "    weights = np.zeros(n_coalitions)\n",
    "    # Every row starts at the reference point; present features are overwritten below.\n",
    "    V = np.tile(np.asarray(reference, dtype=float), (n_coalitions, 1))\n",
    "\n",
    "    for i, s in enumerate(powerset(range(M))):\n",
    "        s = list(s)\n",
    "        V[i, s] = x[s]\n",
    "        X[i, s] = 1\n",
    "        weights[i] = shapley_kernel(M, len(s))\n",
    "    y = f(V)\n",
    "    # Scaling rows by sqrt(weight) turns weighted least squares into ordinary lstsq.\n",
    "    wsq = np.sqrt(weights)\n",
    "    result = np.linalg.lstsq(wsq[:, None] * X, wsq * y, rcond=None)[0]\n",
    "    return result\n",
    "\n",
    "\n",
    "M = 4\n",
    "np.random.seed(1)\n",
    "x = np.random.randn(M)\n",
    "reference = np.zeros(M)\n",
    "phi = kernel_shap(f, x, reference, M)\n",
    "base_value = phi[-1]\n",
    "shap_values = phi[:-1]\n",
    "\n",
    "print(\" reference =\", reference)\n",
    "print(\" x =\", x)\n",
    "print(\"shap_values =\", shap_values)\n",
    "print(\" base_value =\", base_value)\n",
    "print(\" sum(phi) =\", np.sum(phi))\n",
    "print(\" f(x) =\", f(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using KernelExplainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]\n",
      "base value = 10.0\n"
     ]
    }
   ],
   "source": [
    "import shap\n",
    "\n",
    "explainer = shap.KernelExplainer(f, np.reshape(reference, (1, len(reference))))\n",
    "# Use a separate name so the brute-force `shap_values` from the previous cell is not overwritten.\n",
    "explainer_shap_values = explainer.shap_values(x)\n",
    "print(\"shap_values =\", explainer_shap_values)\n",
    "print(\"base value =\", explainer.expected_value)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}